{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 如何减小显存占用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1 减小数据量(筛选高质量数据)\n",
    "2 减小per_device_train_batch_size ✔\n",
    "3 增大gradient_accumulation_steps ✔\n",
    "4 Lora 秩减小(8，16，32) ✔\n",
    "5 target_modules里面的参数减小，一般是[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]\n",
    "6 bfloat16/fp16 导入模型 ✔\n",
    "7 optim = \"adamw_8bit\"· \n",
    "8 z2-no-offload -> z2-offload ->z3-no-offload -> z3-offload(显存占用依次减小，训练速度依次降低) ✔\n",
    "9 8bit量化 4bit量化\n",
    "10 限制数据长度（截断）"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4bit量化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "latex"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n",
    "import torch\n",
    "\n",
    "\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16,  # 计算时使用 bfloat16\n",
    "    bnb_4bit_use_double_quant=True,         # 双层量化，进一步压缩\n",
    "    bnb_4bit_quant_type=\"nf4\",              # 量化类型：nf4 或 fp4\n",
    ")\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    quantization_config=bnb_config,  # 关键参数\n",
    "    device_map=\"auto\",               # 自动分配设备\n",
    "    torch_dtype=torch.bfloat16,      # 模型计算数据类型\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8bit量化\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "latex"
    }
   },
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    load_in_8bit=True,          # 关键参数\n",
    "    device_map=\"auto\",          # 自动分配设备\n",
    "    torch_dtype=torch.bfloat16,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
